import json
import types
import unittest.mock
from pathlib import Path
from typing import Dict, Iterable, List, Union

import avro.schema
import click
from avrogen import write_schema_files


def load_schema_file(schema_file: str) -> str:
    """Read an Avro schema file and return it re-serialized as normalized JSON.

    The file is parsed as JSON and dumped again with a 2-space indent, so the
    returned text is formatted consistently regardless of the source layout.

    Args:
        schema_file: Path to a ``.avsc`` (JSON) schema file.

    Returns:
        The schema as a JSON string indented with 2 spaces.

    Raises:
        OSError: If the file cannot be opened or read.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8 by specification; don't depend on the platform's
    # locale-default encoding when reading it.
    with open(schema_file, encoding="utf-8") as f:
        raw_schema_text = f.read()

    return json.dumps(json.loads(raw_schema_text), indent=2)


def merge_schemas(schemas: List[str]) -> str:
    """Merge Avro schema JSON strings into a single union schema.

    The merged schema is a JSON array (an Avro union) whose first branch is
    ``"null"``, followed by each input schema in the given order. It is parsed
    with avro while ``avro.schema.Names.Register`` is patched so that a named
    type encountered more than once is registered only the first time, and is
    then serialized back to JSON.

    Args:
        schemas: Avro schemas, each given as a JSON string.

    Returns:
        The merged union schema as a JSON string indented with 2 spaces.
    """
    # Combine schemas.
    schemas_obj = [json.loads(schema) for schema in schemas]
    merged = ["null"] + schemas_obj

    # Deduplicate repeated names.
    # Stand-in for avro.schema.Names.Register: the first schema registered
    # under a full name wins, and later ones with the same name are ignored.
    # NOTE(review): presumably the stock Register rejects duplicates (and may
    # perform other validation) -- confirm against the pinned avro version.
    # This relies on the private ``_names`` dict of ``Names``. A later
    # definition whose body differs from the first is dropped without warning.
    def Register(self, schema):
        if schema.fullname in self._names:
            # Duplicate full name: keep the existing registration.
            pass
        else:
            self._names[schema.fullname] = schema

    # The patch is only in effect while the merged union is being parsed.
    with unittest.mock.patch("avro.schema.Names.Register", Register):
        cleaned_schema = avro.schema.SchemaFromJSONData(merged)

    # Convert back to an Avro schema JSON representation.
    # json cannot serialize types.MappingProxyType (a read-only dict view),
    # which may appear in to_json()'s output, so convert it to a plain dict.
    class MappingProxyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, types.MappingProxyType):
                return dict(obj)
            return json.JSONEncoder.default(self, obj)

    out_schema = cleaned_schema.to_json()
    encoded = json.dumps(out_schema, cls=MappingProxyEncoder, indent=2)
    return encoded


# Prepended to every generated .py file by suppress_checks_in_file: disables
# flake8 for the whole file, marks the file as autogenerated, and turns black
# formatting off until the matching footer below.
autogen_header = """# flake8: noqa

# This file is autogenerated by /metadata-ingestion/scripts/avro_codegen.py
# Do not modify manually!

# fmt: off
"""
# Appended to every generated .py file to turn black formatting back on.
autogen_footer = "# fmt: on\n"


def suppress_checks_in_file(filepath: Union[str, Path]) -> None:
    """
    Wrap an autogenerated Python file with lint/format suppression comments.

    The file is rewritten in place as ``autogen_header + contents +
    autogen_footer``:
        - The header silences flake8, states that the file was autogenerated,
          and turns black formatting off (``# fmt: off``).
        - The footer turns black formatting back on (``# fmt: on``).

    If the existing contents do not end with a newline, one is inserted
    before the footer so the footer lands on its own line.

    Args:
        filepath: Path of the file to rewrite.
    """
    path = Path(filepath)
    contents = path.read_text()

    # Keep the footer pragma as a standalone comment line; without this it
    # would become a trailing comment on the file's last statement.
    if contents and not contents.endswith("\n"):
        contents += "\n"

    # write_text truncates, so the file ends up holding exactly the new text.
    path.write_text(autogen_header + contents + autogen_footer)


# Source text emitted at the top of the generated ``schemas/__init__.py``:
# a private helper that reads ``<schema_name>.avsc`` from the directory that
# contains the generated module. This string is concatenated as-is (not
# passed through str.format), so its f-string braces are left untouched.
load_schema_method = """
import functools
import pathlib

def _load_schema(schema_name: str) -> str:
    return (pathlib.Path(__file__).parent / f"{schema_name}.avsc").read_text()
"""
# Per-schema template, filled in with str.format(schema_name=...): emits a
# ``get<schema_name>Schema()`` accessor memoized with lru_cache. Because it
# goes through str.format, any literal braces added here must be doubled.
individual_schema_method = """
@functools.lru_cache(maxsize=None)
def get{schema_name}Schema() -> str:
    return _load_schema("{schema_name}")
"""


def make_load_schema_methods(schemas: Iterable[str]) -> str:
    """Build the Python source for the generated schema-loading module.

    Args:
        schemas: Schema names (file stems). One ``get<name>Schema`` accessor
            is generated per name, in iteration order.

    Returns:
        The shared ``_load_schema`` helper source followed by every
        per-schema accessor, concatenated into a single string.
    """
    accessors = [
        individual_schema_method.format(schema_name=name) for name in schemas
    ]
    return "".join([load_schema_method, *accessors])


@click.command()
@click.argument("schema_files", type=click.Path(exists=True), nargs=-1, required=True)
@click.argument("outdir", type=click.Path(), required=True)
def generate(schema_files: Iterable[str], outdir: str) -> None:
    """Generate Python classes from a set of Avro schema files.

    Steps:
        1. Load and normalize each schema file, keyed by its file stem.
        2. Merge all schemas into one union schema and pass it to avrogen's
           ``write_schema_files`` with OUTDIR as the destination.
        3. Empty ``OUTDIR/__init__.py``.
        4. Save the normalized schemas under ``OUTDIR/schemas/`` and write a
           ``schemas/__init__.py`` exposing a ``get<Name>Schema()`` accessor
           for each one.
        5. Add lint/format suppression comments to every ``.py`` file under
           OUTDIR.

    Raises:
        click.UsageError: If two schema files share a file stem but have
            different contents. They would otherwise silently overwrite each
            other and one schema would be lost.
    """
    schemas: Dict[str, str] = {}
    for schema_file in schema_files:
        name = Path(schema_file).stem
        schema = load_schema_file(schema_file)
        # The stem becomes both the saved file name and the accessor name, so
        # conflicting inputs with the same stem must not clobber each other.
        existing = schemas.get(name)
        if existing is not None and existing != schema:
            raise click.UsageError(
                f"Duplicate schema name {name!r} (from {schema_file}) "
                "with different contents; schema file names must be unique."
            )
        schemas[name] = schema

    merged_schema = merge_schemas(list(schemas.values()))
    write_schema_files(merged_schema, outdir)

    out_path = Path(outdir)
    # Truncate the top-level package __init__.py.
    (out_path / "__init__.py").write_text("")

    # Save raw schema files in codegen as well.
    schema_save_dir = out_path / "schemas"
    # exist_ok lets the script be re-run over an existing OUTDIR instead of
    # failing with FileExistsError.
    schema_save_dir.mkdir(parents=True, exist_ok=True)
    for schema_name, schema in schemas.items():
        (schema_save_dir / f"{schema_name}.avsc").write_text(schema)

    # Add the load_schema helpers. Write (not append) so a re-run does not
    # duplicate the accessor definitions.
    (schema_save_dir / "__init__.py").write_text(
        make_load_schema_methods(schemas.keys())
    )

    # Add headers for all generated files.
    for generated_file in out_path.glob("**/*.py"):
        suppress_checks_in_file(generated_file)


if __name__ == "__main__":
    # Script entry point: click parses SCHEMA_FILES... OUTDIR from sys.argv.
    generate()
